import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from utils.parse_config import *
from collections import defaultdict

def create_modules(module_defs):
    """Build an ``nn.ModuleList`` from parsed darknet-style layer definitions.

    Args:
        module_defs: list of dicts (string values). The first entry holds the
            network hyperparameters and must contain ``'channels'``. It is
            popped off, so the list is mutated in place and afterwards contains
            only layer definitions, aligned index-for-index with the returned
            module list.

    Returns:
        ``(hyperparams, module_list)`` where ``module_list[i]`` is an
        ``nn.Sequential`` for layer ``i`` (empty for layer types this function
        does not recognise).
    """
    hyperparams = module_defs.pop(0)
    # output_filters[0] is the input channel count; output_filters[k + 1] is
    # the channel count produced by layer k.
    output_filters = [int(hyperparams['channels'])]
    module_list = nn.ModuleList()
    for i, module_def in enumerate(module_defs):
        moudles = nn.Sequential()
        # Layers that do not change the channel count (pooling, unknown types)
        # inherit the previous layer's count.
        filters = output_filters[-1]
        layer_type = module_def['type']
        if layer_type == 'convolutional':
            filters = int(module_def['filter'])
            kernel_size = int(module_def['size'])
            # 'pad' is a flag: when set, use "same" padding for odd kernels.
            pad = (kernel_size - 1) // 2 if int(module_def['pad']) else 0
            moudles.add_module('conv_%d' % i, nn.Conv2d(in_channels=output_filters[-1],
                                                        out_channels=filters,
                                                        kernel_size=kernel_size,
                                                        stride=int(module_def['stride']),
                                                        padding=pad))
            activation = module_def['activation']
            if activation == 'Relu':
                moudles.add_module('Relu_%d' % i, nn.ReLU())
            elif activation == 'SOFTMAX':
                moudles.add_module('SOFT_%d' % i, nn.Softmax2d())
            elif activation == 'Linear':
                # NOTE(review): nn.Linear acts on the LAST tensor dim (width of
                # an NCHW conv output), not on channels, so this only runs when
                # width == filters. Darknet's "linear" activation usually means
                # identity -- confirm intent before relying on it.
                moudles.add_module('Line_%d' % i, nn.Linear(filters, filters))

        elif layer_type == 'pooling':
            kernel_size = int(module_def['size'])
            stride = int(module_def['stride'])
            pooling = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)
            moudles.add_module('Pooling_%d' % i, pooling)
        elif layer_type == 'route':
            layers = [int(x) for x in module_def['layer'].split(',')]
            # Skip the input-channel entry so absolute indices line up with the
            # per-layer outputs that forward() concatenates; negative indices
            # are unaffected by the slice.
            layer_filters = output_filters[1:]
            filters = sum(layer_filters[layer_i] for layer_i in layers)
            moudles.add_module('route_%d' % i, EmptyLayer())
        elif layer_type == 'shortcut':
            filters = output_filters[1:][int(module_def['from'])]
            moudles.add_module('shortcut_%d' % i, EmptyLayer())
        elif layer_type == 'linear':
            filters = output_filters[-1]
            moudles.add_module('linear_%d' % i, nn.Linear(filters, filters))

        module_list.append(moudles)
        output_filters.append(filters)
    print('============================================')
    print(output_filters)

    return hyperparams, module_list


class EmptyLayer(nn.Module):
    """Parameter-free placeholder module.

    ``create_modules`` inserts one for every ``route`` / ``shortcut`` layer so
    that ``module_list`` keeps one entry per config layer. The tensor plumbing
    for those layers is done directly in the model's forward pass, which never
    calls this module.
    """

    def __init__(self):
        super().__init__()


class WPOD_NET(nn.Module):
    """WPOD-NET-style detector whose backbone is built from a darknet-style config.

    The layer stack is assembled by ``create_modules`` from the definitions
    returned by ``parse_model_config``. ``forward`` returns the per-cell
    prediction map at inference time, and the scalar training loss when a
    ``target`` is supplied.
    """

    def __init__(self, config_path, img_size=416):
        """Parse *config_path* and build the layer stack.

        Args:
            config_path: path handed to ``parse_model_config``.
            img_size: nominal input resolution. Stored for callers; ``forward``
                derives the output grid size from the actual backbone output.
        """
        super(WPOD_NET, self).__init__()
        self.module_defs = parse_model_config(config_path)
        # create_modules pops the hyperparameter entry off module_defs, so
        # afterwards module_defs[i] describes module_list[i].
        self.hyperparams, self.module_list = create_modules(self.module_defs)
        self.img_size = img_size
        self.seen = 0
        self.header_info = np.array([0, 0, 0, self.seen, 0])
        self.loss_names = ['v1', 'v2', 'v3', 'v4', 'v5', 'v6', 'v7', 'v8']
        # Backbone output stride: input pixels per output cell along each axis.
        self.Ns = 2**4
        self.loss = 0

    def forward(self, x, target=None):
        """Run the network and, if *target* is given, compute the training loss.

        Args:
            x: input batch in NCHW layout (as consumed by ``nn.Conv2d``).
            target: optional corner annotations, reshapeable to
                ``(batch, 2, 4)``; see ``_wpod_loss``.

        Returns:
            The scalar loss when *target* is given. Otherwise the prediction
            map of shape ``(batch, wid, hei, 8)``, where axis 1 indexes output
            columns (width), axis 2 output rows (height), and axis 3 holds the
            eight per-cell values v1..v8.
        """
        self.losses = defaultdict(float)
        is_training = target is not None
        x = self._run_backbone(x)
        # The backbone output is (bs, C, H, W). Move channels last and width
        # before height so the layout really is (bs, wid, hei, 8); a bare
        # view() would interleave channel and spatial values instead.
        bs, _, hei, wid = x.shape
        prediction = x.permute(0, 3, 2, 1).contiguous().view(bs, wid, hei, 8)
        if not is_training:
            return prediction
        loss = self._wpod_loss(prediction, target)
        self.loss = loss
        return loss

    def _run_backbone(self, x):
        """Feed *x* through ``module_list``, resolving route/shortcut references."""
        layer_outputs = []
        for module_def, module in zip(self.module_defs, self.module_list):
            layer_type = module_def['type']
            if layer_type in ('convolutional', 'pooling'):
                x = module(x)
            elif layer_type == 'route':
                # Concatenate the referenced layer outputs along channels.
                indices = [int(v) for v in module_def['layer'].split(',')]
                x = torch.cat([layer_outputs[j] for j in indices], 1)
            elif layer_type == 'shortcut':
                # Residual connection: add the referenced layer(s) to the
                # previous layer's output (darknet [shortcut] semantics).
                # NOTE(review): assumes configs follow darknet semantics.
                indices = [int(v) for v in module_def['from'].split(',')]
                x = layer_outputs[-1] + sum(layer_outputs[j] for j in indices)
            # Any other layer type (e.g. 'linear') passes x through unchanged.
            layer_outputs.append(x)
        return x

    def _wpod_loss(self, prediction, target, alpha=7.75):
        """Compute the WPOD-NET training loss, summed over cells and batch.

        loss = sum_{b,m,n} [ I_obj * f_affine + f_probs ], where
          * T_mn = [[max(v3,0), v4], [v5, max(v6,0)]] @ q + [v7; v8] warps the
            canonical unit square q into the predicted corners;
          * A_mn = (1/alpha) * (target / Ns - [m; n]) are the annotated corners
            expressed relative to cell (m, n);
          * f_affine is the L1 distance between T_mn and A_mn summed over the
            four corners;
          * f_probs = -log(v1) at the object cell, -log(v2) everywhere else.

        Args:
            prediction: ``(bs, wid, hei, 8)`` map from ``forward``.
            target: tensor reshapeable to ``(bs, 2, 4)``. Row 0 is compared with
                the width (axis 1) cell index and row 1 with the height (axis 2)
                index, in input-pixel units. Presumably row 0 = x and row 1 = y
                -- TODO confirm against the data loader.
            alpha: normalisation constant for the corner coordinates.

        Returns:
            Scalar loss tensor, differentiable w.r.t. the network parameters.
        """
        bs, wid, hei, _ = prediction.shape
        dtype, device = prediction.dtype, prediction.device
        v1, v2, v3, v4, v5, v6, v7, v8 = prediction.unbind(dim=3)  # each (bs, wid, hei)

        # Per-cell affine transform applied to the canonical square q.
        affine = torch.stack([
            torch.stack([v3.clamp(min=0), v4], dim=-1),
            torch.stack([v5, v6.clamp(min=0)], dim=-1),
        ], dim=-2)  # (bs, wid, hei, 2, 2)
        translation = torch.stack([v7, v8], dim=-1).unsqueeze(-1)  # (bs, wid, hei, 2, 1)
        q = torch.tensor([[-0.5, -0.5, 0.5, 0.5],
                          [-0.5, 0.5, 0.5, -0.5]], dtype=dtype, device=device)
        T_mn = torch.matmul(affine, q) + translation  # (bs, wid, hei, 2, 4)

        # target is a (bs, 2, 4) tensor of corner coordinates.
        target = target.reshape(bs, 2, 4).to(prediction)
        # Cell coordinates [m; n] for every grid position, broadcast over batch.
        w_idx = torch.arange(wid, dtype=dtype, device=device).view(wid, 1).expand(wid, hei)
        h_idx = torch.arange(hei, dtype=dtype, device=device).view(1, hei).expand(wid, hei)
        loc_mn = torch.stack([w_idx, h_idx], dim=-1).unsqueeze(0).unsqueeze(-1)  # (1, wid, hei, 2, 1)
        A_mn = (1.0 / alpha) * ((1.0 / self.Ns) * target.view(bs, 1, 1, 2, 4) - loc_mn)
        # L1 norm over the coordinate axis, then summed over the four corners.
        f_affine = torch.norm(T_mn - A_mn, p=1, dim=3).sum(3)  # (bs, wid, hei)

        # Object indicator: the cell containing the mean of the four corners.
        # Indices are clamped so annotations at/over the border cannot index
        # out of range (or wrap around for negative values).
        centre = target.mean(dim=2).floor()  # (bs, 2)
        cell = (centre / self.Ns).long()  # truncates toward zero, like int()
        cell_w = cell[:, 0].clamp(0, wid - 1)
        cell_h = cell[:, 1].clamp(0, hei - 1)
        mask_mn = torch.zeros(bs, wid, hei, dtype=dtype, device=device)
        mask_mn[torch.arange(bs, device=device), cell_w, cell_h] = 1

        # Log-loss; assumes v1/v2 are positive probabilities (e.g. produced by
        # a softmax layer in the config) -- otherwise log() yields NaN.
        f_probs = -(mask_mn * torch.log(v1) + (1 - mask_mn) * torch.log(v2))
        return (mask_mn * f_affine + f_probs).sum()

    def save_weights(self, path, cutoff=-1):
        """Save the model's ``state_dict`` to *path* with ``torch.save``.

        ``cutoff`` is accepted for backward compatibility with an older
        darknet-format writer and is ignored.
        """
        torch.save(self.state_dict(), path)

    def load_weights(self, weights_path):
        """Load a ``state_dict`` written by ``save_weights``.

        Tensors are mapped to CPU first and then copied into the existing
        parameters by ``load_state_dict``, so a checkpoint saved from a GPU run
        also loads on a CPU-only machine. ``torch.load`` unpickles the file:
        only load checkpoints from trusted sources.
        """
        self.load_state_dict(torch.load(weights_path, map_location='cpu'))






